#!/usr/bin/env python3

# pyre-strict

import warnings
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor


def _divide_and_aggregate_metrics(
    inputs: Tuple[Tensor, ...],
    n_perturb_samples: int,
    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
    metric_func: Callable,
    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
    agg_func: Callable = torch.add,
    max_examples_per_batch: Optional[int] = None,
) -> Tensor:
    r"""
    This function is used to slice large number of samples `n_perturb_samples` per
    input example into smaller pieces, computing the metrics for each small piece and
    aggregating the results across all `n_perturb_samples` per example. The function
    returns overall aggregated metric per sample. The size of each slice is determined
    by the `max_examples_per_batch` input parameter.

    Args:

        inputs (tuple): The original inputs formatted in a tuple that are passed to
                        the metrics function and that are used to compute the
                        attributions for.
        n_perturb_samples (int): The number of samples per example that are used for
                        perturbation purposes for example.
        metric_func (Callable): This function takes the number of samples per
                        input batch and returns an overall metric for each example.
        agg_func (Callable, optional): This function is used to aggregate the
                        metrics across multiple sub-batches and that are
                        generated by `metric_func`.
        max_examples_per_batch (int, optional): The maximum number of allowed examples
                        per batch.

        Returns:

            metric (Tensor): A metric score estimated by `metric_func` per
                        input example.
    """
    bsz = inputs[0].size(0)

    if max_examples_per_batch is not None and (
        max_examples_per_batch // bsz < 1
        or max_examples_per_batch // bsz > n_perturb_samples
    ):
        warnings.warn(
            (
                "`max_examples_per_batch` must be at least equal to the"
                " input batch size and at most to "
                "`input batch size` * `n_perturb_samples`."
                "`max_examples_per_batch` is: {} and the input batch size is: {}."
                "This is necessary because we require that each sub-batch that is used "
                "to compute the metrics, contains at least an instance of "
                "the original example and doesn't exceed the number of "
                "expanded n_perturb_samples."
            ).format(max_examples_per_batch, bsz),
            stacklevel=1,
        )

    max_inps_per_batch = (
        n_perturb_samples
        if max_examples_per_batch is None
        else min(max(max_examples_per_batch // bsz, 1), n_perturb_samples)
    )

    current_n_steps = max_inps_per_batch

    metrics_sum = metric_func(max_inps_per_batch)

    while current_n_steps < n_perturb_samples:
        current_n_steps += max_inps_per_batch

        metric = metric_func(
            max_inps_per_batch
            if current_n_steps <= n_perturb_samples
            else max_inps_per_batch - (current_n_steps - n_perturb_samples)
        )

        current_n_steps = min(current_n_steps, n_perturb_samples)

        metrics_sum = agg_func(metrics_sum, metric)
    return metrics_sum
